{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_"
      },
      "source": [
        "# Parameter sharing in Haiku"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZFh4JGWqcAww"
      },
      "source": [
        "## Introduction\n",
        "In Haiku, parameter reuse is determined uniquely by *module instance names*, i.e., if a module instance has the same name as another module instance, they share parameters.\n",
        "\n",
        "Unless specified, module names are automatically derived by Haiku from the module *class* name (following a pattern established in TensorFlow 1 with Sonnet V1). In more detail, module naming follows these rules:\n",
        "\n",
        "1. Module names are assigned when the module instance is *constructed*. Unless a module instance name is provided as an argument to the constructor, Haiku generates one from the current *module class name* (basically: `to_snake_case(CurrentClassName)`).\n",
        "2. If the module instance name doesn't end in a `_N` (where `N` is a number) and another module instance with the same name already exists, Haiku adds an incremental number to the end of the new module instance name (e.g. `module_1`).\n",
        "3. When two modules are nested (i.e., a module instance is constructed inside another module's class definition), the inner module name is prefixed with the *outer module name* and, possibly (see the next rule), the name of the *outer module method* being called. The constructor (i.e., `__init__`) is represented by the tilde symbol `~`.\n",
        "4. If the calling method is `__call__`, the method name is ignored (only the *outer module name* is prefixed).\n",
        "5. When there are multiple levels of nesting, the previous rules are applied at each level: each inner module name is derived from the name and calling method of the module immediately above it in the call hierarchy.\n",
        "\n",
        "Let's see how this works with a practical example."
      ]
    },
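    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick illustration of rule 1, you can bypass automatic naming entirely by passing an explicit `name` to a module constructor. A minimal sketch, using the built-in `hk.Linear` (the name `'encoder'` is just an illustrative choice, and the imports are introduced in the next section):\n",
        "\n",
        "```python\n",
        "def net(x):\n",
        "  # The explicit name overrides the auto-generated `linear`.\n",
        "  return hk.Linear(4, name='encoder')(x)\n",
        "\n",
        "params = hk.transform(net).init(jax.random.PRNGKey(0), jnp.ones((2, 3)))\n",
        "# The parameters live under the key 'encoder', not 'linear'.\n",
        "```"
      ]
    },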
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YlHnYngGcAwx"
      },
      "source": [
        "## Flat modules (no nesting)\n",
        "This section covers parameter sharing when the modules are not nested."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "executionInfo": {
          "elapsed": 59,
          "status": "ok",
          "timestamp": 1694690936995,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "xnqV_aMTcAwx"
      },
      "outputs": [],
      "source": [
        "#@title Imports and accessory functions\n",
        "import functools\n",
        "import haiku as hk\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def parameter_shapes(params):\n",
        "  \"\"\"Make printing parameters a little more readable.\"\"\"\n",
        "  return jax.tree.map(lambda p: p.shape, params)\n",
        "\n",
        "\n",
        "def transform_and_print_shapes(fn, x_shape=(2, 3)):\n",
        "  \"\"\"Print name and shape of the parameters.\"\"\"\n",
        "  rng = jax.random.PRNGKey(42)\n",
        "  x = jnp.ones(x_shape)\n",
        "\n",
        "  transformed_fn = hk.transform(fn)\n",
        "  params = transformed_fn.init(rng, x)\n",
        "  print('\\nThe name and shape of the parameters are:')\n",
        "  print(parameter_shapes(params))\n",
        "\n",
        "def assert_all_equal(params_1, params_2):\n",
        "  assert all(jax.tree.leaves(\n",
        "      jax.tree.map(lambda a, b: (a == b).all(), params_1, params_2)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "executionInfo": {
          "elapsed": 65,
          "status": "ok",
          "timestamp": 1694690970489,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "ebmr36pycAwy"
      },
      "outputs": [],
      "source": [
        "w_init = hk.initializers.TruncatedNormal(stddev=1)\n",
        "\n",
        "class SimpleModule(hk.Module):\n",
        "  \"\"\"A simple module class with one variable.\"\"\"\n",
        "\n",
        "  def __init__(self, output_channels, name=None):\n",
        "    super().__init__(name)\n",
        "    assert isinstance(output_channels, int)\n",
        "    self._output_channels = output_channels\n",
        "\n",
        "  def __call__(self, x):\n",
        "    w_shape = (x.shape[-1], self._output_channels)\n",
        "    w = hk.get_parameter(\"w\", w_shape, x.dtype, init=w_init)\n",
        "    return jnp.dot(x, w)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2zxaxRdPcAwy",
        "outputId": "4cec5f2c-7ad8-42c3-a082-22f4007912ec"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The name assigned to \"simple\" is: \"simple_module\".\n",
            "\n",
            "The name and shape of the parameters are:\n",
            "{'simple_module': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  # This instance will be named `simple_module`.\n",
        "  simple = SimpleModule(output_channels=2)\n",
        "  simple_out = simple(x)  # implicitly calls simple.__call__()\n",
        "  print(f'The name assigned to \"simple\" is: \"{simple.module_name}\".')\n",
        "  return simple_out\n",
        "\n",
        "transform_and_print_shapes(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M9MqAvKCcAwz"
      },
      "source": [
        "Great! Here we see that indeed if we create a `SimpleModule` instance and do not specify a name, Haiku assigns it the name `simple_module`. This is also reflected in the parameters associated with the module.\n",
        "\n",
        "What happens if we instantiate `SimpleModule` twice though? Does Haiku assign to both instances the same name?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pcWOjMatcAwz",
        "outputId": "72a663cb-1d58-44d8-ec82-555a5c98acab"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The name assigned to \"simple_one\" is: \"simple_module\".\n",
            "The name assigned to \"simple_two\" is: \"simple_module_1\".\n",
            "\n",
            "The name and shape of the parameters are:\n",
            "{'simple_module': {'w': (3, 2)}, 'simple_module_1': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  # This instance will be named `simple_module`.\n",
        "  simple_one = SimpleModule(output_channels=2)\n",
        "  # This instance will be named `simple_module_1`.\n",
        "  simple_two = SimpleModule(output_channels=2)\n",
        "  first_out = simple_one(x)\n",
        "  second_out = simple_two(x)\n",
        "  print(f'The name assigned to \"simple_one\" is: \"{simple_one.module_name}\".')\n",
        "  print(f'The name assigned to \"simple_two\" is: \"{simple_two.module_name}\".')\n",
        "  return first_out + second_out\n",
        "\n",
        "transform_and_print_shapes(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JxqNp4QmcAwz"
      },
      "source": [
        "As expected, Haiku is smart enough to differentiate the two instances and avoid accidental parameter sharing: the second instance is named `simple_module_1` and each instance has its own set of parameters. Good!\n",
        "\n",
        "But what if we wanted to share parameters? In this case, we would have to instantiate the module only once and *call* it multiple times. Let's see how this works:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NRDYdDTTcAw0",
        "outputId": "42cadf34-43c9-468d-c769-271c712fae01"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The name assigned to \"simple_one\" is: \"simple_module\".\n",
            "\n",
            "The name and shape of the parameters are:\n",
            "{'simple_module': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  # This instance will be named `simple_module`.\n",
        "  simple_one = SimpleModule(output_channels=2)\n",
        "  first_out = simple_one(x)\n",
        "  second_out = simple_one(x)  # share parameters w/ previous call\n",
        "  print(f'The name assigned to \"simple_one\" is: \"{simple_one.module_name}\".')\n",
        "  return first_out + second_out\n",
        "\n",
        "transform_and_print_shapes(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HmvdP78ncAw0"
      },
      "source": [
        "## Nested modules\n",
        "In this section we'll see what happens when we nest one `hk.Module` into another."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L5NOMmJhcAw0"
      },
      "outputs": [],
      "source": [
        "class NestedModule(hk.Module):\n",
        "  \"\"\"A module class with a nested module created in the constructor.\"\"\"\n",
        "\n",
        "  def __init__(self, output_channels, name=None):\n",
        "    super().__init__(name)\n",
        "    assert isinstance(output_channels, int)\n",
        "    self._output_channels = output_channels\n",
        "    self.inner_simple = SimpleModule(self._output_channels)\n",
        "\n",
        "  def __call__(self, x):\n",
        "    w_shape = (x.shape[-1], self._output_channels)\n",
        "    # Another variable that is also called `w`.\n",
        "    w = hk.get_parameter(\"w\", w_shape, x.dtype, init=w_init)\n",
        "    return jnp.dot(x, w) + self.inner_simple(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kl_XzFKWcAw0",
        "outputId": "ea278536-87dd-456d-e735-73ee5bae4dcc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The name assigned to the outer module (i.e., \"nested\") is: \"nested_module\".\n",
            "The name assigned to the inner module (i.e., inside \"nested\") is: \"nested_module/~/simple_module\".\n",
            "\n",
            "The name and shape of the parameters are:\n",
            "{'nested_module': {'w': (3, 2)}, 'nested_module/~/simple_module': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  # This will be named `nested_module` and the SimpleModule instance created\n",
        "  # inside it will be named `nested_module/~/simple_module`.\n",
        "  nested = NestedModule(output_channels=2)\n",
        "  nested_out = nested(x)\n",
        "  print('The name assigned to the outer module (i.e., \"nested\") is: '\n",
        "        f'\"{nested.module_name}\".')\n",
        "  print('The name assigned to the inner module (i.e., inside \"nested\") is: \"'\n",
        "        f'{nested.inner_simple.module_name}\".')\n",
        "  return nested_out\n",
        "\n",
        "transform_and_print_shapes(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SUc0H672cAw0"
      },
      "source": [
        "As expected, the inner module name depends on: (a) the outer module name; and (b) the outer module's method being called.\n",
        "\n",
        "Note also how the outer module's constructor name `__init__` is replaced by a `~` in the parameter names. If the inner module instance were created inside the `__call__` method of the outer module, the inner module instance name would have been `'nested_module/simple_module'`.\n",
        "\n",
        "In this example we defined all the modules from scratch, but the same holds for any of the modules and networks defined in Haiku, e.g., `hk.Linear`, `hk.nets.MLP`, etc. If you are curious, see what happens if you assign to `self.inner_simple` an instance of `hk.Linear` instead of `SimpleModule`."
      ]
    },
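    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the `~` behaviour concrete, here is a sketch of a hypothetical variant (not defined elsewhere in this notebook) that creates its inner module inside `__call__` instead of the constructor:\n",
        "\n",
        "```python\n",
        "class NestedInCall(hk.Module):\n",
        "\n",
        "  def __call__(self, x):\n",
        "    # Created inside `__call__`, so no `~` (and no method name) appears:\n",
        "    # the parameters end up under `nested_in_call/simple_module`.\n",
        "    return SimpleModule(2)(x)\n",
        "```"
      ]
    },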
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "78OZmbeZcAw0"
      },
      "source": [
        "Let's try now multiple levels of nesting:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jYNiChGjcAw0"
      },
      "outputs": [],
      "source": [
        "class TwiceNestedModule(hk.Module):\n",
        "  \"\"\"A module class with a nested module containing a nested module.\"\"\"\n",
        "\n",
        "  def __init__(self, output_channels, name=None):\n",
        "    super().__init__(name)\n",
        "    assert isinstance(output_channels, int)\n",
        "    self._output_channels = output_channels\n",
        "    self.inner_nested = NestedModule(self._output_channels)\n",
        "\n",
        "  def __call__(self, x):\n",
        "    w_shape = (x.shape[-1], self._output_channels)\n",
        "    w = hk.get_parameter(\"w\", w_shape, x.dtype, init=w_init)\n",
        "    return jnp.dot(x, w) + self.inner_nested(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mODRWoM6cAw0",
        "outputId": "b17ee674-51a1-423a-8eed-0f14bcb8e924"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The name assigned to the outermost module is: \"twice_nested_module\".\n",
            "The name assigned to the module inside \"outer\" is: \"twice_nested_module/~/nested_module\".\n",
            "The name assigned to the module inside it is \"twice_nested_module/~/nested_module/~/simple_module\".\n",
            "\n",
            "The name and shape of the parameters are:\n",
            "{'twice_nested_module': {'w': (3, 2)}, 'twice_nested_module/~/nested_module': {'w': (3, 2)}, 'twice_nested_module/~/nested_module/~/simple_module': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  \"\"\"Create the module instances and inspect their names.\"\"\"\n",
        "  # Instantiate a TwiceNestedModule. This will be named `twice_nested_module`.\n",
        "  # The NestedModule created inside it will be named\n",
        "  # `twice_nested_module/~/nested_module`, and the SimpleModule inside that\n",
        "  # `twice_nested_module/~/nested_module/~/simple_module`.\n",
        "  outer = TwiceNestedModule(output_channels=2)\n",
        "  outer_out = outer(x)\n",
        "  print(f'The name assigned to the outermost module is: \"{outer.module_name}\".')\n",
        "  print('The name assigned to the module inside \"outer\" is: \"'\n",
        "        f'{outer.inner_nested.module_name}\".')\n",
        "  print('The name assigned to the module inside it is \"'\n",
        "        f'{outer.inner_nested.inner_simple.module_name}\".')\n",
        "  return outer_out\n",
        "\n",
        "transform_and_print_shapes(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qnfJMYbxcAw0"
      },
      "source": [
        "Great, this also works as expected: the full hierarchy of module names and calls is reflected in the inner module names."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tMWh53qrcAw0"
      },
      "source": [
        "## Multi-transform: merge the parameters without sharing them"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4T-42hTvcAw1"
      },
      "source": [
        "Sometimes when we have multiple transformed functions it can be convenient to merge all the parameters into a single structure, to reduce the number of dictionaries we have to store and pass around. Some of these functions, though, may instantiate the same modules, and we want to make sure that their parameters don't get shared accidentally.\n",
        "\n",
        "`hk.multi_transform` comes to the rescue in this case: it merges the parameters into a single dictionary, making sure that duplicated parameters are renamed to avoid accidental sharing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "871V810NcAw1",
        "outputId": "2cdd6471-aa87-4e04-c114-cf6c2b268bb1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "f parameters: {'linear': {'b': (40,), 'w': (2, 40)}, 'simple_module': {'w': (3, 2)}}\n",
            "g parameters: {'simple_module': {'w': (3, 2)}}\n",
            "\n",
            "The name and shape of the multi-transform parameters are:\n",
            " {'linear': {'b': (40,), 'w': (2, 40)}, 'simple_module': {'w': (3, 2)}, 'simple_module_1': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  \"\"\"A SimpleModule followed by a Linear layer.\"\"\"\n",
        "  module_instance = SimpleModule(output_channels=2)\n",
        "  out = module_instance(x)\n",
        "  linear = hk.Linear(40)\n",
        "  return linear(out)\n",
        "\n",
        "def g(x):\n",
        "  \"\"\"A SimpleModule with its output doubled.\"\"\"\n",
        "  module_instance = SimpleModule(output_channels=2)\n",
        "  return module_instance(x) * 2  # twice\n",
        "\n",
        "# Transform both functions, and print their respective parameter shapes.\n",
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones((2, 3))\n",
        "transformed_f = hk.transform(f)\n",
        "params_f = transformed_f.init(rng, x)\n",
        "transformed_g = hk.transform(g)\n",
        "params_g = transformed_g.init(rng, x)\n",
        "print('f parameters:', parameter_shapes(params_f))\n",
        "print('g parameters:', parameter_shapes(params_g))\n",
        "\n",
        "# Transform both functions at once with hk.multi_transform, and print the\n",
        "# resulting merged parameter structure.\n",
        "\n",
        "def multitransform_f_and_g():\n",
        "  def template(x):\n",
        "    return f(x), g(x)\n",
        "  return template, (f, g)\n",
        "init, (f_apply, g_apply) = hk.multi_transform(multitransform_f_and_g)\n",
        "merged_params = init(rng, x)\n",
        "\n",
        "print('\\nThe name and shape of the multi-transform parameters are:\\n',\n",
        "      parameter_shapes(merged_params))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "chOb1ewKcAw1"
      },
      "source": [
        "In this example `f` and `g` both instantiate a `SimpleModule` instance with the same arguments, and if we transform them separately we see that both dictionaries contain a `'simple_module'` key.\n",
        "\n",
        "When we transform them together instead, `hk.multi_transform` takes care of renaming one of them to `'simple_module_1'`, thus preventing accidental parameter sharing."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AW2ikapTcAw1"
      },
      "source": [
        "## Sharing parameters between transformed functions\n",
        "Now that we understand how module names are assigned and how this affects parameter sharing, let's see how we can share parameters between transformed functions.\n",
        "\n",
        "In this section we will consider two functions, `f` and `g`, and explore different strategies to share parameters. We will look at several cases that differ in how many of the modules instantiated by each function are the same, and whether their parameters have the same shapes."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-qP24cCmcAw1"
      },
      "source": [
        "### Case 1: All modules have the same names, and the same shape\n",
        "Let's reuse one of the modules we created before, and try to instantiate it twice inside two different functions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wB9fLqtFcAw1",
        "outputId": "fd13dca9-48ec-4c28-9762-952955e6a887"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "f parameters: {'simple_module': {'w': (3, 2)}}\n",
            "g parameters: {'simple_module': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  \"\"\"Apply SimpleModule to x.\"\"\"\n",
        "  module_instance = SimpleModule(output_channels=2)\n",
        "  out = module_instance(x)\n",
        "  return out\n",
        "\n",
        "def g(x):\n",
        "  \"\"\"Like f, but doubles the output.\"\"\"\n",
        "  module_instance = SimpleModule(output_channels=2)\n",
        "  out = module_instance(x)\n",
        "  return out * 2\n",
        "\n",
        "# Transform both functions, and print the parameter shapes.\n",
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones((2, 3))\n",
        "\n",
        "transformed_f = hk.transform(f)\n",
        "params_f = transformed_f.init(rng, x)\n",
        "transformed_g = hk.transform(g)\n",
        "params_g = transformed_g.init(rng, x)\n",
        "\n",
        "print('f parameters:', parameter_shapes(params_f))\n",
        "print('g parameters:', parameter_shapes(params_g))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NSGGtbLFcAw1"
      },
      "source": [
        "Great! Since `f` and `g` are using exactly the same modules, the sets of initialized variables generated with each have the same *name structure* (note that the actual values might differ, depending on initialization).\n",
        "\n",
        "Now, if we wanted to share parameters in this case, we could initialize only one of the two functions (e.g., `f`) and use the resulting parameters for both functions, i.e., when we call `transformed_f.apply` and `transformed_g.apply`."
      ]
    },
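    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of that reuse, continuing from the cell above (only `f` is initialized; `g` runs with `f`'s parameters):\n",
        "\n",
        "```python\n",
        "out_f = transformed_f.apply(params_f, rng, x)\n",
        "out_g = transformed_g.apply(params_f, rng, x)  # same `w` as `f`\n",
        "# Since `g` doubles the output of the shared module, out_g == 2 * out_f.\n",
        "```"
      ]
    },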
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ztSEsR8XcAw1"
      },
      "source": [
        "### Case 2: Common modules have the same names, and the same shape\n",
        "This is a nice trick, but what if the functions were not identical? Let's build two such functions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TmkZ7T1pcAw1",
        "outputId": "cf98f7ac-58c6-4387-9486-280a63be5209"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "The name and shape of the f parameters are:\n",
            " {'linear': {'b': (40,), 'w': (2, 40)}, 'simple_module': {'w': (3, 2)}}\n",
            "\n",
            "The name and shape of the g parameters are:\n",
            " {'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (40,), 'w': (10, 40)}, 'simple_module': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  \"\"\"A SimpleModule followed by a Linear layer.\"\"\"\n",
        "  module_instance = SimpleModule(output_channels=2)\n",
        "  out = module_instance(x)\n",
        "  linear = hk.Linear(40)\n",
        "  return linear(out)\n",
        "\n",
        "def g(x):\n",
        "  \"\"\"A SimpleModule followed by an MLP.\"\"\"\n",
        "  module_instance = SimpleModule(output_channels=2)\n",
        "  out = module_instance(x)\n",
        "  linear = hk.nets.MLP((10, 40))\n",
        "  return linear(out)\n",
        "\n",
        "# Transform both functions, and print the parameter shapes.\n",
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones((2, 3))\n",
        "\n",
        "transformed_f = hk.transform(f)\n",
        "params_f = transformed_f.init(rng, x)\n",
        "transformed_g = hk.transform(g)\n",
        "params_g = transformed_g.init(rng, x)\n",
        "\n",
        "print('\\nThe name and shape of the f parameters are:\\n',\n",
        "      parameter_shapes(params_f))\n",
        "print('\\nThe name and shape of the g parameters are:\\n',\n",
        "      parameter_shapes(params_g))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2dCELn97cAw1"
      },
      "source": [
        "Now we have a problem! Both sets of parameters have a `'simple_module'` component, but each also contains parameters that are specific to that function, so we cannot simply initialize only one of the functions and use the returned parameters for both as we did before. But we would still like to share the parameters of `'simple_module'`. How can we do that?\n",
        "\n",
        "One option here is to use [`haiku.data_structures.merge`](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.data_structures.merge) to combine the two sets of parameters. This merges the two structures, keeping only the value from the last structure when both contain the same parameter (i.e., `'simple_module'` in our example). Let's try that:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SOS6azFwcAw1",
        "outputId": "4912e15b-da15-42ed-82aa-1575e40494a1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "The name and shape of the shared parameters are:\n",
            " {'linear': {'b': (40,), 'w': (2, 40)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (40,), 'w': (10, 40)}, 'simple_module': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "merged_params = hk.data_structures.merge(params_f, params_g)\n",
        "print('\\nThe name and shape of the shared parameters are:\\n',\n",
        "      parameter_shapes(merged_params))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-0Pst1wncAw1"
      },
      "source": [
        "Brilliant! Now we have a shared set of parameters that contains all the disjoint parameters and a single set of parameters for the shared `'simple_module'`. Let's verify that we can use this set of parameters when calling either function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "582UosnTcAw1",
        "outputId": "66a9053c-183b-4544-a538-abbbc2f58513"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "f_out mean: 0.037986994\n",
            "g_out mean: 0.104857825\n"
          ]
        }
      ],
      "source": [
        "f_out = transformed_f.apply(merged_params, rng, x)\n",
        "g_out = transformed_g.apply(merged_params, rng, x)\n",
        "\n",
        "print('f_out mean:', f_out.mean())\n",
        "print('g_out mean:', g_out.mean())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x2f5t10RcAw1"
      },
      "source": [
        "This gives us little control over what gets shared though: what if the two functions had parameters with the same name that we don't want to share?\n",
        "\n",
        "### Case 3: Common modules have the same names, but different shapes\n",
        "Let's modify our previous example to use a `hk.Linear` layer in both functions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WLFwu1_9cAw1",
        "outputId": "ff725c0d-cbf4-4a60-9a8b-dd37b3540a8b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "The name and shape of the f parameters are:\n",
            " {'linear': {'b': (4,), 'w': (5, 4)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}\n",
            "\n",
            "The name and shape of the g parameters are:\n",
            " {'linear': {'b': (20,), 'w': (5, 20)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  \"\"\"A SimpleModule followed by an MLP and a Linear layer.\"\"\"\n",
        "  module_instance = SimpleModule(output_channels=2)\n",
        "  out = module_instance(x)\n",
        "  mlp = hk.nets.MLP((10, 5))\n",
        "  out = mlp(out)\n",
        "  last_linear = hk.Linear(4)\n",
        "  return last_linear(out)\n",
        "\n",
        "def g(x):\n",
        "  \"\"\"Same as f, with a bigger final layer.\"\"\"\n",
        "  module_instance = SimpleModule(output_channels=2)\n",
        "  out = module_instance(x)\n",
        "  mlp = hk.nets.MLP((10, 5))\n",
        "  out = mlp(out)\n",
        "  last_linear = hk.Linear(20)  # another Linear, but bigger\n",
        "  return last_linear(out)\n",
        "\n",
        "# Transform both functions, and print the parameter shapes.\n",
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones((2, 3))\n",
        "\n",
        "transformed_f = hk.transform(f)\n",
        "params_f = transformed_f.init(rng, x)\n",
        "transformed_g = hk.transform(g)\n",
        "params_g = transformed_g.init(rng, x)\n",
        "\n",
        "print('\\nThe name and shape of the f parameters are:\\n',\n",
        "      parameter_shapes(params_f))\n",
        "print('\\nThe name and shape of the g parameters are:\\n',\n",
        "      parameter_shapes(params_g))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AdrFlwhEcAw2"
      },
      "source": [
        "Now we have a problem! Both sets of parameters have a `'linear'` component, but their respective parameters have different shapes. If we merged them as we did before, the `'linear'` parameters from `f` would be dropped and we couldn't use the merged parameters to call `transformed_f.apply`:\n",
        "\n",
        "```python\n",
        "merged_params = hk.data_structures.merge(params_f, params_g)\n",
        "print('\\nThe name and shape of the merged parameters are:\\n',\n",
        "      parameter_shapes(merged_params))\n",
        "\n",
        "f_out = transformed_f.apply(merged_params, rng, x)  # fails\n",
        "# ValueError: 'linear/w' with retrieved shape (5, 20) does not match shape=[5, 4] dtype=dtype('float32')\n",
        "```"
      ]
    },
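    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before merging, we can check which module names collide and whether their shapes actually match. A small sketch (reusing `params_f`, `params_g` and the `parameter_shapes` helper defined at the top):\n",
        "\n",
        "```python\n",
        "shared = set(params_f) & set(params_g)\n",
        "conflicting = {name for name in shared\n",
        "               if parameter_shapes(params_f[name]) != parameter_shapes(params_g[name])}\n",
        "# Only 'linear' has mismatched shapes; the other shared modules merge safely.\n",
        "```"
      ]
    },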
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LLikYl1-cAw2"
      },
      "source": [
        "How can we share the parameters of `'simple_module'` and `mlp`, but keep the parameters of the two output `linear` layers separated?\n",
        "\n",
        "A solution would be to instantiate `simple_module` and `mlp` outside of the functions, so that they get created only once, and then use those instances in both functions. But all Haiku modules must be created inside an `hk.transform`, so doing so naively results in an error:\n",
        "\n",
        "```python\n",
        "module_instance = SimpleModule(output_channels=2)  # this fails\n",
        "# ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.\n",
        "mlp = hk.nets.MLP((10, 5))\n",
        "\n",
        "def f(x):\n",
        "  \"\"\"A SimpleModule followed by a Linear layer.\"\"\"\n",
        "  out = module_instance(x)\n",
        "  out = mlp(out)\n",
        "  linear = hk.Linear(4)\n",
        "  return linear(out)\n",
        "\n",
        "def g(x):\n",
        "  \"\"\"A SimpleModule followed by a bigger Linear layer.\"\"\"\n",
        "  out = module_instance(x)\n",
        "  out = mlp(out)\n",
        "  linear = hk.Linear(20)  # another Linear, but bigger\n",
        "  return linear(out)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ApVnyyuccAw2"
      },
      "source": [
        "We can work around that by creating another function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PcPUkXDEcAw2",
        "outputId": "cf3c97a7-d80f-4eb1-b62b-bf80c408d558"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "The name and shape of the f parameters are:\n",
            " {'linear': {'b': (4,), 'w': (5, 4)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}\n",
            "\n",
            "The name and shape of the g parameters are:\n",
            " {'linear': {'b': (20,), 'w': (5, 20)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}\n",
            "\n",
            "The MLP parameters are shared!\n"
          ]
        }
      ],
      "source": [
        "class CachedModule():\n",
        "\n",
        "  def __call__(self, *inputs):\n",
        "    # Create the instances if are not in the cache.\n",
        "    if not hasattr(self, 'cached_simple_module'):\n",
        "      self.cached_simple_module = SimpleModule(output_channels=2)\n",
        "    if not hasattr(self, 'cached_mlp'):\n",
        "      self.cached_mlp = hk.nets.MLP((10, 5))\n",
        "\n",
        "    # Apply the cached instances.\n",
        "    out = self.cached_simple_module(*inputs)\n",
        "    out = self.cached_mlp(out)\n",
        "    return out\n",
        "\n",
        "\n",
        "def f(x):\n",
        "  \"\"\"A SimpleModule followed by a Linear layer.\"\"\"\n",
        "  shared_preprocessing = CachedModule()\n",
        "  out = shared_preprocessing(x)\n",
        "  linear = hk.Linear(4)\n",
        "  return linear(out)\n",
        "\n",
        "def g(x):\n",
        "  \"\"\"A SimpleModule followed by a bigger Linear layer.\"\"\"\n",
        "  shared_preprocessing = CachedModule()\n",
        "  out = shared_preprocessing(x)\n",
        "  linear = hk.Linear(20)  # another Linear, but bigger\n",
        "  return linear(out)\n",
        "\n",
        "\n",
        "# Transform both functions, and print the parameter shapes.\n",
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones((2, 3))\n",
        "\n",
        "transformed_f = hk.transform(f)\n",
        "params_f = transformed_f.init(rng, x)\n",
        "transformed_g = hk.transform(g)\n",
        "params_g = transformed_g.init(rng, x)\n",
        "\n",
        "print('\\nThe name and shape of the f parameters are:\\n',\n",
        "      parameter_shapes(params_f))\n",
        "print('\\nThe name and shape of the g parameters are:\\n',\n",
        "      parameter_shapes(params_g))\n",
        "\n",
        "# Verify that the simple module parameters are shared.\n",
        "assert_all_equal(params_f['mlp/~/linear_0'],\n",
        "                 params_g['mlp/~/linear_0'])\n",
        "assert_all_equal(params_f['mlp/~/linear_1'],\n",
        "                 params_g['mlp/~/linear_1'])\n",
        "print('\\nThe MLP parameters are shared!')"
      ]
    },
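    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The key ingredient above is lazy caching: each `CachedModule` instance constructs its modules on the first call (which happens inside the transform, so Haiku allows it) and reuses them on subsequent calls to the same instance. Stripped of Haiku, the pattern looks like this minimal sketch, with `object()` standing in for module construction:\n",
        "\n",
        "```python\n",
        "class Cached:\n",
        "\n",
        "  def __call__(self):\n",
        "    # Build the expensive object on the first call, reuse it afterwards.\n",
        "    if not hasattr(self, '_obj'):\n",
        "      self._obj = object()  # stands in for constructing a module\n",
        "    return self._obj\n",
        "\n",
        "cached = Cached()\n",
        "assert cached() is cached()  # the same instance is reused\n",
        "```"
      ]
    },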
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8Lj4v7E7cAw2"
      },
      "source": [
        "If we want to share a big number of modules it can become tedious to cache each one of them manually inside of `CachedModule`. Furthermore, it would be nice if we didn't have to define a different `CachedModule` object for every function we want to cache.\n",
        "\n",
        "We can use `hk.to_module` to create a more general `CachedModule` object that takes an arbitrary Haiku function and caches it:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "baIBTu2qcAw2",
        "outputId": "63817f7a-5906-4c6f-ab05-fb37c5a0bac7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "The name and shape of the f parameters are:\n",
            " {'linear': {'b': (4,), 'w': (5, 4)}, 'shared_preprocessing_fn/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing_fn/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing_fn/simple_module': {'w': (3, 2)}}\n",
            "\n",
            "The name and shape of the g parameters are:\n",
            " {'linear': {'b': (20,), 'w': (5, 20)}, 'shared_preprocessing_fn/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing_fn/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing_fn/simple_module': {'w': (3, 2)}}\n",
            "\n",
            "The MLP parameters are shared!\n"
          ]
        }
      ],
      "source": [
        "class CachedModule():\n",
        "  \"\"\"Cache one instance of the function and call it multiple times.\"\"\"\n",
        "  def __init__(self, fn):\n",
        "    self._fn = fn\n",
        "\n",
        "  def __call__(self, *args, **kwargs):\n",
        "    if not hasattr(self, \"_instance\"):\n",
        "      ModularisedFn = hk.to_module(self._fn)\n",
        "      self._instance = ModularisedFn()\n",
        "    return self._instance(*args, **kwargs)\n",
        "\n",
        "def shared_preprocessing_fn(x):\n",
        "  simple_module = SimpleModule(output_channels=2)\n",
        "  out = simple_module(x)\n",
        "  mlp = hk.nets.MLP((10, 5))\n",
        "  return mlp(out)\n",
        "\n",
        "def f(x):\n",
        "  \"\"\"A SimpleModule followed by a Linear layer.\"\"\"\n",
        "  shared_preprocessing = CachedModule(shared_preprocessing_fn)\n",
        "  out = shared_preprocessing(x)\n",
        "  linear = hk.Linear(4)\n",
        "  return linear(out)\n",
        "\n",
        "def g(x):\n",
        "  \"\"\"A SimpleModule followed by a bigger Linear layer.\"\"\"\n",
        "  shared_preprocessing = CachedModule(shared_preprocessing_fn)\n",
        "  out = shared_preprocessing(x)\n",
        "  linear = hk.Linear(20)  # another Linear, but bigger\n",
        "  return linear(out)\n",
        "\n",
        "\n",
        "# Transform both functions, and print the parameter shapes.\n",
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones((2, 3))\n",
        "\n",
        "transformed_f = hk.transform(f)\n",
        "params_f = transformed_f.init(rng, x)\n",
        "transformed_g = hk.transform(g)\n",
        "params_g = transformed_g.init(rng, x)\n",
        "\n",
        "print('\\nThe name and shape of the f parameters are:\\n',\n",
        "      parameter_shapes(params_f))\n",
        "print('\\nThe name and shape of the g parameters are:\\n',\n",
        "      parameter_shapes(params_g))\n",
        "\n",
        "# Verify that the simple module parameters are shared.\n",
        "assert_all_equal(params_f['shared_preprocessing_fn/mlp/~/linear_0'],\n",
        "                 params_g['shared_preprocessing_fn/mlp/~/linear_0'])\n",
        "assert_all_equal(params_f['shared_preprocessing_fn/mlp/~/linear_1'],\n",
        "                 params_g['shared_preprocessing_fn/mlp/~/linear_1'])\n",
        "print('\\nThe MLP parameters are shared!')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1Rpb5VrKcAw2"
      },
      "source": [
        "When we work with objects it can also be convenient to define a decorator to do the same:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "executionInfo": {
          "elapsed": 2449,
          "status": "ok",
          "timestamp": 1694690980671,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "hvhwCHFdcAw2",
        "outputId": "f74be564-0b3e-4a6f-8cb0-f5af4adb4192"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "The name and shape of the f parameters are:\n",
            " {'linear': {'b': (4,), 'w': (5, 4)}, 'shared_preprocessing/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing/simple_module': {'w': (3, 2)}}\n",
            "\n",
            "The name and shape of the g parameters are:\n",
            " {'linear': {'b': (20,), 'w': (5, 20)}, 'shared_preprocessing/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing/simple_module': {'w': (3, 2)}}\n",
            "\n",
            "The MLP parameters are shared!\n"
          ]
        }
      ],
      "source": [
        "def share_parameters():\n",
        "  def decorator(fn):\n",
        "    def wrapper(*args, **kwargs):\n",
        "      if wrapper.instance is None:\n",
        "        wrapper.instance = hk.to_module(fn)()\n",
        "      return wrapper.instance(*args, **kwargs)\n",
        "    wrapper.instance = None\n",
        "    return functools.wraps(fn)(wrapper)\n",
        "  return decorator\n",
        "\n",
        "\n",
        "class Wrapper():\n",
        "\n",
        "  @share_parameters()\n",
        "  def shared_preprocessing(self, x):\n",
        "    simple_module = SimpleModule(output_channels=2)\n",
        "    out = simple_module(x)\n",
        "    mlp = hk.nets.MLP((10, 5))\n",
        "    return mlp(out)\n",
        "\n",
        "  def f(self, x):\n",
        "    \"\"\"A SimpleModule followed by a Linear layer.\"\"\"\n",
        "    out = self.shared_preprocessing(x)\n",
        "    linear = hk.Linear(4)\n",
        "    return linear(out)\n",
        "\n",
        "  def g(self, x):\n",
        "    \"\"\"A SimpleModule followed by a bigger Linear layer.\"\"\"\n",
        "    out = self.shared_preprocessing(x)\n",
        "    linear = hk.Linear(20)  # another Linear, but bigger\n",
        "    return linear(out)\n",
        "\n",
        "# Transform both functions, and print the parameter shapes.\n",
        "rng = jax.random.PRNGKey(42)\n",
        "x = jnp.ones((2, 3))\n",
        "\n",
        "wrapper = Wrapper()\n",
        "transformed_f = hk.transform(wrapper.f)\n",
        "params_f = transformed_f.init(rng, x)\n",
        "transformed_g = hk.transform(wrapper.g)\n",
        "params_g = transformed_g.init(rng, x)\n",
        "\n",
        "print('\\nThe name and shape of the f parameters are:\\n',\n",
        "      parameter_shapes(params_f))\n",
        "print('\\nThe name and shape of the g parameters are:\\n',\n",
        "      parameter_shapes(params_g))\n",
        "\n",
        "# Verify that the simple module parameters are shared.\n",
        "assert_all_equal(params_f['shared_preprocessing/mlp/~/linear_0'],\n",
        "                 params_g['shared_preprocessing/mlp/~/linear_0'])\n",
        "assert_all_equal(params_f['shared_preprocessing/mlp/~/linear_1'],\n",
        "                 params_g['shared_preprocessing/mlp/~/linear_1'])\n",
        "print('\\nThe MLP parameters are shared!')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
